#!/usr/bin/env python3

# Number of target classes for each supported dataset, keyed by dataset
# identifier. NOTE(review): the "torch/" prefix presumably names the dataset
# loader backend — confirm against the code that consumes these keys.
NUM_CLASSES = {
    "torch/cifar10": 10,
    "torch/cifar100": 100,
    "torch/imagenet": 1000,
}

# Reference learning rate and the batch size it corresponds to. Each ViT
# size's learning rate is scaled linearly with its batch size:
#     lr = (batch_size / 512) * 5.0e-4
_VIT_BASE_LEARNING_RATE = 5.0e-4
_VIT_BASE_BATCH_SIZE = 512


def _vit_config(depth, num_heads, embedding_dim, batch_size):
    """Return one ViT size entry, deriving its learning rate from *batch_size*.

    Deriving the rate (instead of hard-coding it next to the batch size) keeps
    the two in sync; previously the "large" entry listed a rate that matched
    batch size 64 rather than its actual batch size of 32.
    """
    return {
        "depth": depth,
        "num_heads": num_heads,
        "embedding_dim": embedding_dim,
        "batch_size": batch_size,
        # NOTE: key is misspelled ("leaning" for "learning") but is kept as-is
        # because existing callers look it up under this exact name.
        "leaning_rate": (batch_size / _VIT_BASE_BATCH_SIZE) * _VIT_BASE_LEARNING_RATE,
    }


# Per-size ViT architecture and training hyperparameters. Every entry exposes
# the same keys: depth, num_heads, embedding_dim, batch_size, leaning_rate.
# ("large" previously used "embed_dim", which broke uniform lookups.)
VIT_MODEL_SIZE = {
    "tiny": _vit_config(depth=12, num_heads=3, embedding_dim=192, batch_size=256),
    "small": _vit_config(depth=12, num_heads=6, embedding_dim=384, batch_size=256),
    "base": _vit_config(depth=12, num_heads=12, embedding_dim=768, batch_size=128),
    "large": _vit_config(depth=24, num_heads=16, embedding_dim=1024, batch_size=32),
}

# GPT-2 architecture hyperparameters, keyed by model size name. Each entry
# has the keys depth, num_heads, num_tokens and embedding_dim.
# NOTE(review): num_tokens presumably is the maximum sequence/context length
# and depth the number of transformer blocks — confirm with the model code.
GPT2_MODEL_SIZE = dict(
    small=dict(depth=12, num_heads=12, num_tokens=1024, embedding_dim=768),
    medium=dict(depth=24, num_heads=16, num_tokens=1024, embedding_dim=1024),
)
